{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3-dXXfV_pGa8"
      },
      "outputs": [],
      "source": [
        "# Copyright 2023 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xv3oVInApMae"
      },
      "source": [
        "# Visualizing embedding similarity from text documents using t-SNE plots\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Run in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/embeddings/embedding-similarity-visualization.ipynb\">\n",
        "      <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/embedding-similarity-visualization.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "af3e9cd6b94d"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "|Author(s) | [Gabe Rives-Corbett](https://github.com/grivescorbett) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a41adaaac995"
      },
      "source": [
        "This notebook demonstrates how vector similarity is relevant to LLM-generated embeddings. You will embed a collection of labelled documents and then plot the embeddings on a two-dimensional t-SNE plot to observe how similar documents tend to cluster together based on their embeddings."
      ]
    },
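    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of what vector similarity means, the sketch below computes cosine similarity for a few made-up vectors using only the Python standard library. The `cosine_similarity` helper and the three-dimensional vectors are illustrative assumptions, not output from the embedding model; cosine similarity is one common measure, chosen here for simplicity.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def cosine_similarity(a, b):\n",
        "    # Dot product divided by the product of the two vector lengths\n",
        "    dot = sum(x * y for x, y in zip(a, b))\n",
        "    norm_a = math.sqrt(sum(x * x for x in a))\n",
        "    norm_b = math.sqrt(sum(y * y for y in b))\n",
        "    return dot / (norm_a * norm_b)\n",
        "\n",
        "\n",
        "# Hypothetical 3-dimensional vectors standing in for document embeddings\n",
        "hockey_1 = [0.9, 0.1, 0.0]\n",
        "hockey_2 = [0.8, 0.2, 0.1]\n",
        "space_1 = [0.1, 0.2, 0.9]\n",
        "\n",
        "print(cosine_similarity(hockey_1, hockey_2))  # similar direction, close to 1\n",
        "print(cosine_similarity(hockey_1, space_1))  # different direction, much lower\n",
        "```\n",
        "\n",
        "Vectors that point in nearly the same direction score close to 1, while vectors pointing in unrelated directions score closer to 0."
      ]
    },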
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2tdCmJnzpkmi"
      },
      "source": [
        "## Getting started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ODKbeVNopmrM"
      },
      "source": [
        "### Install libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bM6F3k69UdCE"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade google-genai scikit-learn pandas seaborn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3s04fztzqYek"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you are running this notebook on Google Colab, you will need to authenticate your environment. To do this, run the new cell below. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a4mvUNbzqZH8"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "06fe990961b7"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d5f508350c34"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "from google import genai\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KYN5w2llqRiN"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nMZ9npzcVDn6"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "from google.api_core import retry\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "from sklearn.datasets import fetch_20newsgroups\n",
        "from sklearn.manifold import TSNE\n",
        "from sklearn.model_selection import train_test_split\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "tqdm.pandas()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gmck9vqZqj-U"
      },
      "source": [
        "## Fetch and clean the data\n",
        "\n",
        "In this example, you will use the open source [20 Newsgroups](http://qwone.com/~jason/20Newsgroups/) dataset, a collection of approximately 20,000 newsgroup documents, partitioned (nearly) evenly across 20 different newsgroups"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zgF9FsB5aTc9"
      },
      "outputs": [],
      "source": [
        "newsgroups = fetch_20newsgroups(\n",
        "    categories=[\"comp.graphics\", \"sci.space\", \"sci.med\", \"rec.sport.hockey\"]\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XuBYC5Q-ac7K"
      },
      "outputs": [],
      "source": [
        "raw_data = pd.DataFrame(\n",
        "    {\n",
        "        \"text\": newsgroups.data,\n",
        "        \"target\": [newsgroups.target_names[x] for x in newsgroups.target],\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CexQC3-0rEsg"
      },
      "source": [
        "Because of the 8k input token limit, in this example you will exclude all documents that have a length outside this limit.\n",
        "\n",
        "Even though tokens typically are >=1 characters, for simplicity, you can just filter for documents that have <= 8000 _characters_."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FBjP-_Voadij"
      },
      "outputs": [],
      "source": [
        "raw_data = raw_data.loc[raw_data[\"text\"].str.len() <= 8000]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t51GoVa4rMGM"
      },
      "source": [
        "Subsample the dataset into 500 data points, stratified on the label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BDEzf8eucexS"
      },
      "outputs": [],
      "source": [
        "x_subsample, _, y_subsample, _ = train_test_split(\n",
        "    raw_data[\"text\"], raw_data[\"target\"], stratify=raw_data[\"target\"], train_size=500\n",
        ")"
      ]
    },
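    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make stratification concrete, here is a minimal standard-library sketch that samples from each label in proportion to its share of the data. It is a hand-rolled illustration, not how `train_test_split` is implemented; `stratified_sample` and the toy labels are hypothetical.\n",
        "\n",
        "```python\n",
        "import random\n",
        "from collections import Counter, defaultdict\n",
        "\n",
        "\n",
        "def stratified_sample(items, labels, n, seed=0):\n",
        "    # Group items by label, then draw from each group in proportion to its size.\n",
        "    # Rounding means the total can differ slightly from n for uneven splits.\n",
        "    rng = random.Random(seed)\n",
        "    groups = defaultdict(list)\n",
        "    for item, label in zip(items, labels):\n",
        "        groups[label].append(item)\n",
        "    sample = []\n",
        "    for label, group in groups.items():\n",
        "        k = round(n * len(group) / len(items))\n",
        "        sample.extend((item, label) for item in rng.sample(group, k))\n",
        "    return sample\n",
        "\n",
        "\n",
        "# Hypothetical toy dataset: 60 / 30 / 10 documents across three labels\n",
        "labels = ['a'] * 60 + ['b'] * 30 + ['c'] * 10\n",
        "items = list(range(len(labels)))\n",
        "\n",
        "picked = stratified_sample(items, labels, n=10)\n",
        "print(Counter(label for _, label in picked))\n",
        "```\n",
        "\n",
        "The sample keeps the 6 / 3 / 1 ratio of the full toy dataset, which is the property the `stratify` argument is used for above."
      ]
    },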
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FAYckCCcrU8Z"
      },
      "source": [
        "Clean out the text removing by emails, names, etc. This will help improve the data that will then be converted into embeddings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z7SMLgsTddIW"
      },
      "outputs": [],
      "source": [
        "x_subsample = [re.sub(r\"[\\w\\.-]+@[\\w\\.-]+\", \"\", d) for d in x_subsample]  # Remove email\n",
        "x_subsample = [re.sub(r\"\\([^()]*\\)\", \"\", d) for d in x_subsample]  # Remove names\n",
        "x_subsample = [d.replace(\"From: \", \"\") for d in x_subsample]  # Remove \"From: \"\n",
        "x_subsample = [\n",
        "    d.replace(\"\\nSubject: \", \"\") for d in x_subsample\n",
        "]  # Remove \"\\nSubject: \""
      ]
    },
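    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of the cleaning steps can be checked on a single string. This sketch applies the same four substitutions to a made-up header snippet (the sample text is an assumption, not taken from the dataset) and prints the result.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "# Hypothetical header-style snippet, not taken from the dataset\n",
        "doc = 'From: jane@example.com (Jane Doe)\\nSubject: Re: orbits\\nBody text'\n",
        "\n",
        "doc = re.sub(r'[\\w\\.-]+@[\\w\\.-]+', '', doc)  # Drop the email address\n",
        "doc = re.sub(r'\\([^()]*\\)', '', doc)  # Drop parenthesized text, here the name\n",
        "doc = doc.replace('From: ', '')  # Drop the sender label\n",
        "doc = doc.replace('\\nSubject: ', '')  # Drop the subject label and its newline\n",
        "\n",
        "print(repr(doc))\n",
        "```\n",
        "\n",
        "Note that the parenthesis pattern removes any parenthesized text, not only names, so some body content may be dropped as well."
      ]
    },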
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ak41eeiddyM"
      },
      "outputs": [],
      "source": [
        "df = pd.DataFrame({\"text\": x_subsample, \"target\": list(y_subsample)})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "easflrmcrbFw"
      },
      "source": [
        "You now have 500 data points roughly evenly distributed across the categories:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JEzHxV_9d4xl"
      },
      "outputs": [],
      "source": [
        "df[\"target\"].value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-VeSVs5hriMq"
      },
      "source": [
        "## Create and visualize the embeddings using a t-SNE plot\n",
        "\n",
        "Load the text embedding model from Vertex AI ([documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-embeddings)).\n",
        "\n",
        "Since we are using these embeddings for visualization, we will set the [task type](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types) to clustering."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jte4vh6oeR9X"
      },
      "outputs": [],
      "source": [
        "MODEL_ID = \"text-embedding-005\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "csrvfLqRe0qC"
      },
      "outputs": [],
      "source": [
        "# Retrieve embeddings from the specified model with retry logic\n",
        "\n",
        "\n",
        "def get_embeddings():\n",
        "    @retry.Retry(timeout=300.0)\n",
        "    def embed_fn(contents: str) -> list[float]:\n",
        "        response = client.models.embed_content(\n",
        "            model=MODEL_ID,\n",
        "            contents=contents,\n",
        "        )\n",
        "        return response.embeddings[0].values\n",
        "\n",
        "    return embed_fn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nQX_nVIyrzCl"
      },
      "source": [
        "Create the embeddings. This may take a minute or two."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WUySDL6ffQDt"
      },
      "outputs": [],
      "source": [
        "df[\"embeddings\"] = df[\"text\"].progress_apply(get_embeddings())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HkKoUkLuffsS"
      },
      "outputs": [],
      "source": [
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YrwmPDv4r2oc"
      },
      "source": [
        "The vectors generate by our model are 768 dimensions, and so visualizing across 768 dimensions is impossible. Instead, you can use [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) to reduce to 2 dimensions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AzMvneXyhQcO"
      },
      "outputs": [],
      "source": [
        "tsne = TSNE(random_state=0, n_iter=1000)\n",
        "tsne_results = tsne.fit_transform(\n",
        "    np.array(df[\"embeddings\"].to_list(), dtype=np.float32)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bf6NxWXuhS3q"
      },
      "outputs": [],
      "source": [
        "df_tsne = pd.DataFrame(tsne_results, columns=[\"TSNE1\", \"TSNE2\"])\n",
        "df_tsne[\"target\"] = df[\"target\"]  # Add labels column from df_train to df_tsne"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OaGVCe64hhjU"
      },
      "outputs": [],
      "source": [
        "df_tsne.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B8vlRIWdsFz0"
      },
      "source": [
        "Plot the data points. It should now be visually clear how the documents from the same newsgroup show up close to each other in the vector space with text embeddings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aPgAMTklhkXV"
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots(figsize=(8, 6))  # Set figsize\n",
        "sns.set_style(\"darkgrid\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})\n",
        "sns.scatterplot(data=df_tsne, x=\"TSNE1\", y=\"TSNE2\", hue=\"target\", palette=\"hls\")\n",
        "sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))\n",
        "plt.title(\"Scatter plot of news using t-SNE\")\n",
        "plt.xlabel(\"TSNE1\")\n",
        "plt.ylabel(\"TSNE2\")\n",
        "plt.axis(\"equal\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "embedding-similarity-visualization.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
